import numpy as np

import matplotlib as mpl
import matplotlib.patches as patches
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from mpl_toolkits.axes_grid1.inset_locator import mark_inset

import tempfile
import subprocess

# Select the pgf backend and route all text through LaTeX (pdflatex) with a
# serif font; the preamble loads the `times` package.
mpl.use('pgf')
mpl.rcParams.update({
    "pgf.texsystem": "pdflatex",
    'font.family': 'serif',
    'text.usetex': True,
    'pgf.rcfonts': False,
    'pgf.preamble': r'\usepackage{times}'
})
# pyplot is imported only after the backend and rcParams have been configured.
from matplotlib import pyplot as plt

import mpmath
from mpmath import mp, mpf
# Run all mpmath arithmetic with 1000 decimal digits of working precision.
mp.dps = 1000


def game_matrix(delta):
    """Returns the utility matrix for player 1.

    The matrix is the negation of [[0.5 + delta, 0.5], [0, 1]], with every
    entry stored as an mpmath ``mpf`` so that later arithmetic runs at the
    configured mpmath precision.
    """
    first_row = [mpf(0.5+delta), mpf(1/2)]
    second_row = [mpf(0), mpf(1)]
    return -np.array([first_row, second_row])

def simulate(A, eta, prox_fn, T, omd=True):
    """
    Run T rounds of self-play on the 2x2 game A and record the trajectory.

    Both players start from the uniform strategy. In every round the
    gradients ``A @ y`` (for x) and ``-A.T @ x`` (for y) are computed, the
    current state is recorded, and both strategies are updated through
    ``prox_fn(gradient_signal, center, eta)``:

    * ``omd=True``:  signal ``2 * g_t - g_{t-1}``, centred at the current iterate.
    * ``omd=False``: signal ``sum_{s<=t} g_s + g_t``, centred at ``np.ones(2)``.

    Returns the tuple ``(xs, ys, gaps, ts)`` of per-round lists: ``x[0]``,
    ``y[0]``, the equilibrium gap ``max(A @ y) + max(-A.T @ x)``, and the
    round index. All values are recorded *before* that round's update.
    """
    # Uniform starting strategies.
    x = np.ones(2) / mpf(2)
    y = np.ones(2) / mpf(2)

    # Gradients from the previous round (zero before the first round).
    prev_gx = np.zeros(2) * mpf(0)
    prev_gy = np.zeros(2) * mpf(0)

    # Running sums of every gradient observed so far.
    cum_gx = np.zeros(2) * mpf(0)
    cum_gy = np.zeros(2) * mpf(0)

    xs, ys, gaps, ts = [], [], [], []

    for t in range(T):
        cur_gx = A @ y
        cur_gy = -A.T @ x
        cum_gx += cur_gx
        cum_gy += cur_gy

        ts.append(t)
        xs.append(x[0])
        ys.append(y[0])
        gaps.append(cur_gx.max() + cur_gy.max())

        if not omd:
            # FTRL update: cumulative gradients plus the latest one as prediction.
            x = prox_fn(cum_gx + cur_gx, np.ones(2), eta)
            y = prox_fn(cum_gy + cur_gy, np.ones(2), eta)
        else:
            # OMD update: step from the current iterate.
            x = prox_fn(2 * cur_gx - prev_gx, x, eta)
            y = prox_fn(2 * cur_gy - prev_gy, y, eta)

        # Both iterates must remain on the probability simplex.
        for iterate in (x, y):
            assert (iterate >= 0).all() and abs(iterate.sum() - 1.0) < 1e-10

        prev_gx, prev_gy = cur_gx, cur_gy

    return xs, ys, gaps, ts

def entropy_prox(g, x0, eta):
    """
    Entropic proximal step for a two-action strategy.

    Returns the solution to argmax <g, x> - 1/eta KL(x, x0): x0 reweighted
    entrywise by exp(eta * g) and renormalised to sum to one.
    """
    # Subtract the largest entry so every exponent is <= 0; the common factor
    # cancels in the normalisation below.
    top = g.max()
    factors = np.array([mp.exp(eta * (g[i] - top)) for i in range(2)])
    unnormalised = x0 * factors
    return unnormalised / unnormalised.sum()

def log_prox(g, x0, eta):
    """
    Returns the solution to argmax <g, x> - 1/eta D(x, x0) over the 2-simplex,
    where D(x, x0) is the Bregman divergence associated with the log
    regularizer R(x) = -log(x[0]) - log(x[1]).

    Since grad R(x0)[i] = -1/x0[i], up to constants

        D(x, x0) = -log(x[0]) - log(x[1]) + x[0]/x0[0] + x[1]/x0[1],

    so the problem is

    argmax (eta g[0] - 1/x0[0]) x[0] + (eta g[1] - 1/x0[1]) x[1] + log(x[0]) + log(x[1])

    Letting x[0] = a, we have the 1D optimization problem
    argmax c a + log(a) + log(1-a),   c = eta (g[0] - g[1]) - (1/x0[0] - 1/x0[1]).

    Setting the derivative c + 1/a - 1/(1-a) to zero and writing k = -c gives
    the quadratic k a^2 - (k + 2) a + 1 = 0, whose root inside (0, 1) is

        a = 1/k + 1/2 - sign(k) sqrt(1/4 + 1/k^2)      (a = 1/2 when k = 0).

    With g = 0 this returns x0, as a proximal step must. When x0 has equal
    entries (e.g. the np.ones(2) centre used by the FTRL update) the x0 terms
    cancel and only eta (g[0] - g[1]) matters.
    """
    # The x0 terms enter with a minus sign: they come from -<grad R(x0), x>
    # with grad R(x0) = -1/x0, divided through by the common 1/eta factor.
    c = eta * (g[0] - g[1]) - (1/x0[0] - 1/x0[1])
    k = -c

    if k == 0:
        a = mpf(0.5)
    elif k > 0:
        a = (1/k + 0.5) - mp.sqrt(0.25 + 1 / (k*k))
    else:
        a = (1/k + 0.5) + mp.sqrt(0.25 + 1 / (k*k))
    # The selected root always lies strictly inside the unit interval.
    assert a > 0 and a < 1
    return np.array([a, 1-a])

def euclidean_prox(g, x0, eta):
    """
    Returns the solution to argmax <g, x> - 1/(2eta) ||x - x0||^2 over
    two-entry vectors x = (a, 1 - a) with a clipped to [0, 1].

    Multiplying by 2 eta and dropping constants, this is

    argmax (2 eta g[0] + 2 x0[0]) x[0] + (2 eta g[1] + 2 x0[1]) x[1] - x[0]^2 - x[1]^2
    = argmax (2 eta g[0] - 2 eta g[1] + 2 x0[0] - 2 x0[1]) a - 2 a^2 + 2 a
    = argmax (2 + 2 eta g[0] - 2 eta g[1] + 2 x0[0] - 2 x0[1]) a - 2 a^2
    = clip( (2 + 2 eta g[0] - 2 eta g[1] + 2 x0[0] - 2 x0[1]) / 4 ).
    """
    grad_diff = g[0] - g[1]
    center_diff = x0[0] - x0[1]
    # Unconstrained maximiser of the 1D quadratic in a.
    unclipped = 1/2 + (eta * grad_diff + center_diff) / 2
    first = np.clip(unclipped, 0, 1)
    return np.array([first, 1 - first])

def filter_data(xs, ys, gaps, ts, tol=0.003, en=lambda x: x):
    """
    Thin out a trajectory so that consecutive kept points are visibly apart.

    The first point is always kept. A later point is kept when its distance to
    the most recently kept point,

        |en(x) - en(x_kept)| / (max en(x) - min en(x)) + |y - y_kept|,

    exceeds ``tol``. If every transformed x value is identical (zero range),
    the x term is treated as 0 instead of dividing by zero, so filtering is
    driven by y alone.

    Parameters:
        xs, ys, gaps, ts: parallel sequences describing the trajectory.
        tol: minimum distance between consecutive kept points.
        en: transform applied to x before measuring distances.

    Returns:
        (xs_filt, ys_filt, gaps_filt, ts_filt), the kept entries of each input.

    Prints the x range and the number of discarded points.
    """
    xs_filt = [xs[0]]
    ys_filt = [ys[0]]
    gaps_filt = [gaps[0]]
    ts_filt = [ts[0]]

    # Transform each x exactly once; en may be expensive (e.g. high-precision logs).
    ens = [en(x) for x in xs]
    x_span = max(ens) - min(ens)
    print("X range:", x_span)

    last_en = ens[0]  # transformed x of the most recently kept point
    for (x, e, y, gap, t) in zip(xs, ens, ys, gaps, ts):
        # A zero span means all transformed x values coincide: no x distance.
        dx = abs(e - last_en) / x_span if x_span != 0 else 0
        if dx + abs(y - ys_filt[-1]) > tol:
            xs_filt.append(x)
            ys_filt.append(y)
            gaps_filt.append(gap)
            ts_filt.append(t)
            last_en = e
    print("Filtered out", len(xs) - len(xs_filt), "points out of", len(xs), "total.")

    return xs_filt, ys_filt, gaps_filt, ts_filt

def make_plot(outfile, title, delta, eta, prox_fn, T=1000, omd=True, energy_plot=False, ylabel=False, inset=None):
    """
    Simulate one run and write a two-panel figure to ``outfile``.

    The top panel shows the trajectory (x[0] on the horizontal axis, or
    log(x) - log(1-x) when ``energy_plot`` is set, against y[0]); the bottom
    panel shows the equilibrium gap per iteration on log-log axes. Three
    iterations are highlighted in both panels:

    * T1 (red star):    first iteration with x[0] >= 1/(1+delta).
    * T2 (blue circle): first iteration after T1 with y[0] >= 0.5/(1+delta).
    * T3 (green square, optional): first iteration after T2 whose gap lies in
      [0.1, ys[T2]]; omitted when no such iteration exists.

    Parameters:
        outfile: output file name; the figure is rendered into a temporary
            directory and then cropped into ``outfile`` with ``pdfcrop``.
        title: title of the top panel.
        delta, eta, prox_fn, T, omd: forwarded to game_matrix / simulate.
        energy_plot: use the log-odds transform on the horizontal axis.
        ylabel: draw y-axis labels (and widen the figure slightly).
        inset: None for no inset; 1 for an inset plotting -log10(1 - x)
            against y; 2 for an inset zoomed on the plain (x, y) trajectory.

    Raises:
        ValueError: if ``inset`` is truthy but not 1 or 2.
        RuntimeError: if T1 or T2 is never reached within T iterations.
        subprocess.CalledProcessError: if ``pdfcrop`` exits with an error.
    """
    # Reject unsupported inset modes before running the expensive simulation.
    if inset and inset not in (1, 2):
        raise ValueError(f"inset must be None, 1 or 2; got {inset!r}")

    A = game_matrix(delta)
    xs, ys, gaps, ts = simulate(A, eta, prox_fn, T=T, omd=omd)

    x_threshold = 1/(1+delta)
    T1 = next((t for (t, x) in zip(ts, xs) if x >= x_threshold), None)
    if T1 is None:
        raise RuntimeError(f"x never reached 1/(1+delta) within {T} iterations")
    assert ts[T1] == T1

    y_threshold = 0.5/(1+delta)
    T2 = next((t for (t, y) in zip(ts, ys) if t > T1 and y >= y_threshold), None)
    if T2 is None:
        raise RuntimeError(f"y never reached 0.5/(1+delta) after T1 within {T} iterations")
    assert ts[T2] == T2

    # NOTE(review): the upper bound on the gap is y at T2; this mirrors the
    # value the search has always compared against -- confirm it is intended.
    y_at_T2 = ys[T2]
    T3 = next((t for (t, gap) in zip(ts, gaps) if t > T2 and y_at_T2 >= gap >= 0.1), None)

    print("T1 is", T1, "T2 is", T2)

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(.84*2.5 + (0.1 if ylabel else 0.0),.9*4.8))

    if energy_plot:
        def en(x):
            return mpmath.log(x) - mpmath.log(1-x)
        ax1.set_xlabel(r"\!$\log(x^t[1]) - \log(1-x^t[1])$")
    else:
        def en(x):
            return x
        ax1.set_xlabel("$x^t[1]$")
    if ylabel:
        ax1.set_ylabel("$y^t[1]$")

    # Filter so that we do not have too many points in the plot.
    xs_filt, ys_filt, gaps_filt, ts_filt = filter_data(xs, ys, gaps, ts, en=en)

    ax1.plot([en(x) for x in xs_filt], list(ys_filt), 'o-', ms=.8, lw=.9)
    if inset == 1:
        axins = inset_axes(ax1, .4,.4, loc=2, bbox_to_anchor=(.595, 0.86),bbox_transform=ax1.figure.transFigure)
        axins.set_ylim((0.5-5e-2, 0.5+5e-2))
        axins_xs = [-mp.log10(1-x) for x in xs_filt]
        axins_ys = list(ys_filt)
        print(axins_xs[:3])
        axins_xs = np.array(axins_xs)
        # Only every third filtered point is drawn in the inset.
        axins.plot(axins_xs[::3], axins_ys[::3], 'o-', ms=.3, lw=.5)
        axins.set_xlim((10, None))
        axins.tick_params(left = False, labelleft = False, bottom=False, labelbottom=False)
    elif inset == 2:
        axins = inset_axes(ax1, .4,.4, loc=2, bbox_to_anchor=(.57, 0.86),bbox_transform=ax1.figure.transFigure)
        axins.set_xlim((0.978, 1.002))
        axins.set_ylim((0.484, 0.508))
        axins.plot(xs_filt, ys_filt, 'o-', ms=.3, lw=.5)
        axins.tick_params(left = False, labelleft = False, bottom=False, labelbottom=False)
    if inset:
        # Brown frame on the main axes plus two connector lines to the inset.
        ax1.add_patch(patches.Rectangle((0.97, 0.45), 0.04, 0.08, fill=False, edgecolor='brown', lw=.5, zorder=1))
        for side in ('left', 'right', 'top', 'bottom'):
            axins.spines[side].set_color('brown')
        ax1.plot([0.97, .765], [0.45, 0.485], c='brown', lw=.5)
        ax1.plot([0.97+.04, .924], [0.45+.08, 0.770], c='brown', lw=.5)

    ax1.plot([en(xs[T1])], [ys[T1]], 'r*', ms=7, mew=.5, mec='black')
    ax1.plot([en(xs[T2])], [ys[T2]], 'o', ms=5.5, mew=.5, mec='black', mfc='blue')
    if T3 is not None:
        ax1.plot([en(xs[T3])], [ys[T3]], 'gs', ms=5, mew=.5, mec='black',)
    ax1.set_ylim(0, 1)
    ax1.grid()

    ax2.loglog(ts_filt, gaps_filt, 'o-', ms=1, lw=1)
    ax2.loglog([T1], [gaps[T1]], 'r*', ms=7, mew=.5, mec='black')
    ax2.loglog([T2], [gaps[T2]], 'o', ms=5.5, mew=.5, mec='black', mfc='blue')
    if T3 is not None:
        ax2.plot([T3], [gaps[T3]], 'gs', ms=5, mew=.5, mec='black',)
    ax2.grid()
    ax2.set_xlabel("Iteration")
    if ylabel:
        ax2.set_ylabel("Equilibrium gap")
    ax1.set_title(title)
    fig.tight_layout()
    try:
        with tempfile.TemporaryDirectory() as tmpdirname:
            print("Created temporary directory", tmpdirname)
            tmp_path = tmpdirname + "/" + outfile
            fig.savefig(tmp_path, bbox_inches="tight")
            # The uncropped file is deleted with the temporary directory, so a
            # failed crop must not pass silently.
            subprocess.run(["pdfcrop", tmp_path, outfile], check=True)
    finally:
        # Release the figure; pyplot otherwise keeps every figure alive.
        plt.close(fig)

def make_delta_effect(outfile, titles, deltas, eta, prox_fn, T=1000, omd=True, ylabel=False):
    """
    Plot the equilibrium gap over time for several values of delta, side by side.

    One log-log panel is drawn per entry of ``deltas`` (shared y axis). In
    each panel two iterations are highlighted:

    * T1 (red star):    first iteration with x[0] >= 1/(1+delta).
    * T2 (blue circle): first iteration after T1 with y[0] >= 0.5/(1+delta).

    Parameters:
        outfile: output file name; the figure is rendered into a temporary
            directory and then cropped into ``outfile`` with ``pdfcrop``.
        titles: panel titles, at least one per entry of ``deltas``.
        deltas: game parameters, one simulation per entry.
        eta, prox_fn, T, omd: forwarded to simulate.
        ylabel: label the y axis of the first panel (and widen the figure slightly).

    Raises:
        ValueError: if there are fewer titles than deltas.
        RuntimeError: if T1 or T2 is never reached for some delta.
        subprocess.CalledProcessError: if ``pdfcrop`` exits with an error.
    """
    # Fail before running any simulation rather than midway through plotting.
    if len(titles) < len(deltas):
        raise ValueError("need one title per delta")

    # squeeze=False keeps the axes in an array even for a single delta.
    fig, axs = plt.subplots(1, len(deltas), figsize=(.84*2.7*len(deltas) + (0.1 if ylabel else 0.0),.9*2.4), sharey=True, squeeze=False)
    axs = axs[0]

    for i, delta in enumerate(deltas):
        ax = axs[i]
        A = game_matrix(delta)
        xs, ys, gaps, ts = simulate(A, eta, prox_fn, T=T, omd=omd)
        _, _, gaps_filt, ts_filt = filter_data(xs, ys, gaps, ts, tol=0.003)

        # The markers are searched afresh for every delta so that a failed
        # search can never silently reuse the previous panel's values.
        x_threshold = 1/(1+delta)
        T1 = next((t for (t, x) in zip(ts, xs) if x >= x_threshold), None)
        if T1 is None:
            raise RuntimeError(f"x never reached 1/(1+delta) for delta={delta}")
        assert ts[T1] == T1

        y_threshold = 0.5/(1+delta)
        T2 = next((t for (t, y) in zip(ts, ys) if t > T1 and y >= y_threshold), None)
        if T2 is None:
            raise RuntimeError(f"y never reached 0.5/(1+delta) after T1 for delta={delta}")
        assert ts[T2] == T2

        ax.loglog(ts_filt, gaps_filt, 'o-', ms=1, lw=1)
        ax.loglog([T1], [gaps[T1]], 'r*', ms=7, mew=.5, mec='black')
        ax.loglog([T2], [gaps[T2]], 'o', ms=5.5, mew=.5, mec='black', mfc='blue')
        nticks = 9
        maj_loc = mpl.ticker.LogLocator(numticks=nticks)
        min_loc = mpl.ticker.LogLocator(subs='all', numticks=nticks)
        ax.xaxis.set_major_locator(maj_loc)
        ax.xaxis.set_minor_locator(min_loc)

        ax.grid()
        ax.set_xlabel("Iteration")
        ax.set_title(titles[i])
        if ylabel and i == 0:
            ax.set_ylabel("Equilibrium gap")

    fig.tight_layout()
    try:
        with tempfile.TemporaryDirectory() as tmpdirname:
            print("Created temporary directory", tmpdirname)
            tmp_path = tmpdirname + "/" + outfile
            fig.savefig(tmp_path, bbox_inches="tight")
            # The uncropped file is deleted with the temporary directory, so a
            # failed crop must not pass silently.
            subprocess.run(["pdfcrop", tmp_path, outfile], check=True)
    finally:
        # Release the figure; pyplot otherwise keeps every figure alive.
        plt.close(fig)

if __name__ == '__main__':
    # Single-run figures: (output file, keyword arguments for make_plot).
    single_runs = [
        ("ftr_ent.pdf", dict(
            title="Entropy (OMWU)", delta=mpf("1e-2"), eta=mpf(0.25), T=10000,
            prox_fn=entropy_prox, omd=False, energy_plot=False, ylabel=True, inset=1,
        )),
        ("ftr_log.pdf", dict(
            title="Log regularizer", delta=mpf("1e-2"), eta=mpf(2.0), T=10000,
            prox_fn=log_prox, omd=False, energy_plot=False,
        )),
        ("ftr_euc.pdf", dict(
            title="Sq. Euclidean norm", delta=mpf("1e-2"), eta=mpf(0.1), T=10000,
            prox_fn=euclidean_prox, omd=False,
        )),
        ("omd_euc.pdf", dict(
            title="OGDA", delta=mpf("1e-2"), eta=mpf(0.1), T=10000,
            prox_fn=euclidean_prox, omd=True, inset=2,
        )),
    ]
    for outfile, kwargs in single_runs:
        make_plot(outfile, **kwargs)

    # Effect of delta on the equilibrium gap, one panel per delta.
    make_delta_effect(
        "delta_omwu.pdf",
        titles=[r"$\delta=0.05$", r"$\delta=0.01$", r"$\delta=0.005$"],
        deltas=[mpf("5e-2"), mpf("1e-2"), mpf("5e-3")],
        eta=mpf(0.1),
        T=10000,
        prox_fn=entropy_prox,
        omd=False,
        ylabel=True,
    )